{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1eed6edf-66e1-46e9-a9ac-296568a28ff9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from sklearn.datasets import load_iris\n",
    "from sklearn.model_selection import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "24220559-2b02-4c26-ae7e-eaa69f3abaff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Iris dataset bundled with scikit-learn.\n",
    "iris = load_iris()\n",
    "data = iris.data\n",
    "# One-hot encode the labels: indexing a 3x3 identity matrix with the label\n",
    "# array picks one identity row per sample.\n",
    "target = torch.eye(3)[iris.target]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "b324a239-b439-4cca-8476-130c0710c33b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split into a training set (70%) and a test set (30%).\n",
    "# random_state pins the shuffle so Restart & Run All reproduces the same split;\n",
    "# stratify uses the integer class labels so both subsets keep the class balance.\n",
    "data_train, data_test, target_train, target_test = train_test_split(\n",
    "    data, target, test_size=0.3, random_state=42, stratify=iris.target)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "96cab984-e94f-445c-bea5-6c63e4962b2a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the neural network.\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Three-layer fully connected classifier: 4 -> 5 -> 5 -> 3, softmax output.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc1 = nn.Linear(4, 5)  # input features -> first hidden layer\n",
    "        self.fc2 = nn.Linear(5, 5)  # first hidden -> second hidden layer\n",
    "        self.fc3 = nn.Linear(5, 3)  # second hidden -> one output per class\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a batch ``x`` (last dimension 4) to class probabilities.\n",
    "\n",
    "        ReLU follows the first two layers; the result of ``fc3`` is normalised\n",
    "        with softmax along dim 1, so for a 2-D input each row sums to 1.\n",
    "        \"\"\"\n",
    "        hidden = F.relu(self.fc1(x))\n",
    "        hidden = F.relu(self.fc2(hidden))\n",
    "        logits = self.fc3(hidden)\n",
    "        return F.softmax(logits, dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "37a82111-a57a-4681-92a5-df90d5580436",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the training step.\n",
    "def train(net, X, Y, optimizer, criterion):\n",
    "    \"\"\"Run one optimisation step on (X, Y) and return the loss as a float.\n",
    "\n",
    "    Puts ``net`` in training mode, evaluates ``criterion(net(X), Y)``,\n",
    "    back-propagates it and applies a single ``optimizer.step()``.\n",
    "    \"\"\"\n",
    "    net.train()\n",
    "    batch_loss = criterion(net(X), Y)  # forward pass and loss\n",
    "    optimizer.zero_grad()  # drop stale gradients so backward() does not accumulate\n",
    "    batch_loss.backward()\n",
    "    optimizer.step()\n",
    "    return batch_loss.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "7dd959a7-9d24-4864-addc-e73e1378cdc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the evaluation function.\n",
    "def test(net, X, Y):\n",
    "    \"\"\"Return the accuracy of ``net`` on (X, Y) as a Python float.\n",
    "\n",
    "    Predictions and targets are both reduced with arg-max along dim 1, so\n",
    "    ``Y`` is compared through the index of its largest entry per row.\n",
    "    \"\"\"\n",
    "    net.eval()\n",
    "    with torch.no_grad():  # no gradients are needed for evaluation\n",
    "        predicted = net(X).argmax(dim=1)\n",
    "        expected = Y.argmax(dim=1)\n",
    "        accuracy = (predicted == expected).float().mean()\n",
    "    return accuracy.item()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "76adcd74-3dbe-46e3-acaf-c1444777fe51",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed PyTorch before building the model so the initial weights, and therefore\n",
    "# the whole training run, are reproducible after a kernel restart.\n",
    "torch.manual_seed(42)\n",
    "net = Net()\n",
    "learning_rate = 0.1\n",
    "# Binary cross-entropy, applied element-wise to the network's softmax outputs\n",
    "# against the one-hot targets.\n",
    "# NOTE(review): for single-label multi-class data nn.CrossEntropyLoss on raw\n",
    "# logits is the conventional choice; switching would also require removing the\n",
    "# softmax from Net.forward.\n",
    "criterion = nn.BCELoss()\n",
    "optimizer = torch.optim.SGD(net.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "02b80064-d80a-4e87-8148-2588ef1b4fba",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of training iterations and the per-epoch metric histories.\n",
    "epochs = 5000\n",
    "train_losses, test_losses, test_accs = [], [], []"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6e167d32-318a-42a3-b036-3fee797d8426",
   "metadata": {},
   "source": [
    "代码首先从`sklearn`库中导入`load_iris`和`train_test_split`函数，用于载入Iris数据集并将其划分为训练集（70%）和测试集（30%）。神经网络由3个`nn.Linear`全连接层组成：输入层接收4个特征，第一和第二个隐藏层各有5个神经元，输出层有3个神经元，并经过softmax得到各类别的概率。训练时使用BCELoss（二分类交叉熵损失）作为损失函数，将softmax输出与one-hot标签逐元素比较；优化器为SGD（随机梯度下降优化器），学习率为0.1。每一轮训练后都会记录训练集损失、测试集损失和测试集准确率，并且每100轮输出一次，用于监控训练进程。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a6e48a4d-f594-422c-9ea1-788cbc6993a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/tmp/ipykernel_112307/1158182061.py:3: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  torch.tensor(target_train, dtype=torch.float32),\n",
      "/tmp/ipykernel_112307/1158182061.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  torch.tensor(target_test, dtype=torch.float32))\n",
      "/tmp/ipykernel_112307/1158182061.py:12: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  torch.tensor(target_test, dtype=torch.float32))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [100/5000], train_loss: 0.3696, test_loss: 0.3488, test_acc: 0.8667\n",
      "Epoch [200/5000], train_loss: 0.2536, test_loss: 0.2405, test_acc: 0.9778\n",
      "Epoch [300/5000], train_loss: 0.1476, test_loss: 0.1203, test_acc: 1.0000\n",
      "Epoch [400/5000], train_loss: 0.1592, test_loss: 0.0993, test_acc: 0.9111\n",
      "Epoch [500/5000], train_loss: 0.1120, test_loss: 0.0582, test_acc: 1.0000\n",
      "Epoch [600/5000], train_loss: 0.0609, test_loss: 0.0402, test_acc: 1.0000\n",
      "Epoch [700/5000], train_loss: 0.0526, test_loss: 0.0400, test_acc: 1.0000\n",
      "Epoch [800/5000], train_loss: 0.0504, test_loss: 0.0400, test_acc: 0.9778\n",
      "Epoch [900/5000], train_loss: 0.0430, test_loss: 0.0365, test_acc: 0.9778\n",
      "Epoch [1000/5000], train_loss: 0.0430, test_loss: 0.0378, test_acc: 0.9778\n",
      "Epoch [1100/5000], train_loss: 0.0391, test_loss: 0.0367, test_acc: 0.9778\n",
      "Epoch [1200/5000], train_loss: 0.3182, test_loss: 0.2735, test_acc: 0.8444\n",
      "Epoch [1300/5000], train_loss: 0.0371, test_loss: 0.0372, test_acc: 0.9778\n",
      "Epoch [1400/5000], train_loss: 0.0354, test_loss: 0.0377, test_acc: 0.9778\n",
      "Epoch [1500/5000], train_loss: 0.0537, test_loss: 0.0087, test_acc: 1.0000\n",
      "Epoch [1600/5000], train_loss: 0.0376, test_loss: 0.0214, test_acc: 1.0000\n",
      "Epoch [1700/5000], train_loss: 0.0381, test_loss: 0.0195, test_acc: 1.0000\n",
      "Epoch [1800/5000], train_loss: 0.0389, test_loss: 0.0173, test_acc: 1.0000\n",
      "Epoch [1900/5000], train_loss: 0.0379, test_loss: 0.0179, test_acc: 1.0000\n",
      "Epoch [2000/5000], train_loss: 0.0373, test_loss: 0.0181, test_acc: 1.0000\n",
      "Epoch [2100/5000], train_loss: 0.0369, test_loss: 0.0182, test_acc: 1.0000\n",
      "Epoch [2200/5000], train_loss: 0.0364, test_loss: 0.0183, test_acc: 1.0000\n",
      "Epoch [2300/5000], train_loss: 0.0360, test_loss: 0.0184, test_acc: 0.9778\n",
      "Epoch [2400/5000], train_loss: 0.0357, test_loss: 0.0184, test_acc: 0.9778\n",
      "Epoch [2500/5000], train_loss: 0.0353, test_loss: 0.0185, test_acc: 0.9778\n",
      "Epoch [2600/5000], train_loss: 0.0351, test_loss: 0.0185, test_acc: 0.9778\n",
      "Epoch [2700/5000], train_loss: 0.0348, test_loss: 0.0185, test_acc: 0.9778\n",
      "Epoch [2800/5000], train_loss: 0.0345, test_loss: 0.0186, test_acc: 0.9778\n",
      "Epoch [2900/5000], train_loss: 0.0343, test_loss: 0.0186, test_acc: 0.9778\n",
      "Epoch [3000/5000], train_loss: 0.0341, test_loss: 0.0187, test_acc: 0.9778\n",
      "Epoch [3100/5000], train_loss: 0.0338, test_loss: 0.0187, test_acc: 0.9778\n",
      "Epoch [3200/5000], train_loss: 0.0336, test_loss: 0.0187, test_acc: 0.9778\n",
      "Epoch [3300/5000], train_loss: 0.0335, test_loss: 0.0188, test_acc: 0.9778\n",
      "Epoch [3400/5000], train_loss: 0.0333, test_loss: 0.0188, test_acc: 0.9778\n",
      "Epoch [3500/5000], train_loss: 0.0331, test_loss: 0.0189, test_acc: 0.9778\n",
      "Epoch [3600/5000], train_loss: 0.0329, test_loss: 0.0189, test_acc: 0.9778\n",
      "Epoch [3700/5000], train_loss: 0.0328, test_loss: 0.0190, test_acc: 0.9778\n",
      "Epoch [3800/5000], train_loss: 0.0326, test_loss: 0.0190, test_acc: 0.9778\n",
      "Epoch [3900/5000], train_loss: 0.0324, test_loss: 0.0191, test_acc: 0.9778\n",
      "Epoch [4000/5000], train_loss: 0.0323, test_loss: 0.0192, test_acc: 0.9778\n",
      "Epoch [4100/5000], train_loss: 0.0321, test_loss: 0.0193, test_acc: 0.9778\n",
      "Epoch [4200/5000], train_loss: 0.0320, test_loss: 0.0193, test_acc: 0.9778\n",
      "Epoch [4300/5000], train_loss: 0.0318, test_loss: 0.0194, test_acc: 0.9778\n",
      "Epoch [4400/5000], train_loss: 0.0317, test_loss: 0.0195, test_acc: 0.9778\n",
      "Epoch [4500/5000], train_loss: 0.0316, test_loss: 0.0196, test_acc: 0.9778\n",
      "Epoch [4600/5000], train_loss: 0.0314, test_loss: 0.0197, test_acc: 0.9778\n",
      "Epoch [4700/5000], train_loss: 0.0313, test_loss: 0.0198, test_acc: 0.9778\n",
      "Epoch [4800/5000], train_loss: 0.0312, test_loss: 0.0198, test_acc: 0.9778\n",
      "Epoch [4900/5000], train_loss: 0.0310, test_loss: 0.0201, test_acc: 0.9778\n",
      "Epoch [5000/5000], train_loss: 0.0309, test_loss: 0.0201, test_acc: 0.9778\n"
     ]
    }
   ],
   "source": [
    "# Convert the splits to float32 tensors once, before the loop, instead of on\n",
    "# every iteration. torch.as_tensor converts NumPy arrays and passes existing\n",
    "# tensors through, so it does not raise the copy-construct UserWarning that\n",
    "# torch.tensor() emits when given a tensor.\n",
    "X_train = torch.as_tensor(data_train, dtype=torch.float32)\n",
    "Y_train = torch.as_tensor(target_train, dtype=torch.float32)\n",
    "X_test = torch.as_tensor(data_test, dtype=torch.float32)\n",
    "Y_test = torch.as_tensor(target_test, dtype=torch.float32)\n",
    "\n",
    "for i in range(epochs):\n",
    "    # One full-batch optimisation step on the training set.\n",
    "    train_loss = train(net, X_train, Y_train, optimizer, criterion)\n",
    "    train_losses.append(train_loss)\n",
    "\n",
    "    # train() leaves the model in training mode; evaluate the test loss in\n",
    "    # eval mode and without building an autograd graph.\n",
    "    net.eval()\n",
    "    with torch.no_grad():\n",
    "        test_loss = criterion(net(X_test), Y_test)\n",
    "    test_losses.append(test_loss.item())\n",
    "\n",
    "    test_acc = test(net, X_test, Y_test)\n",
    "    test_accs.append(test_acc)\n",
    "\n",
    "    # Report progress every 100 epochs.\n",
    "    if (i+1) % 100 == 0:\n",
    "        print('Epoch [{}/{}], train_loss: {:.4f}, test_loss: {:.4f}, test_acc: {:.4f}'\n",
    "              .format(i+1, epochs, train_loss, test_loss.item(), test_acc))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc96b100-c4c1-4a45-9fff-b53095abc66c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
